#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   ：try -> 乘法层
@IDE    ：PyCharm
@Author ：csl_forever
@Date   ：2020/9/17 19:31
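@Desc   ：Multiplication layer with forward/backward passes, plus a demo computing an apple purchase price with tax and its gradients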
=================================================='''


class MulLayer:
    """Multiplication node with a forward pass and a matching backward pass.

    ``forward`` caches both operands; ``backward`` reads those cached
    operands to scale the incoming derivative, so it relies on the values
    stored by the most recent ``forward`` call.
    """

    def __init__(self):
        # Operands of the latest forward call; None until forward has run.
        self.x = self.y = None

    def forward(self, x, y):
        """Remember ``x`` and ``y`` and return their product."""
        self.x, self.y = x, y
        return x * y

    def backward(self, dout):
        """Return ``(dx, dy)`` for the upstream derivative ``dout``.

        Each operand's derivative is ``dout`` times the *other* operand:
        ``dx = dout * y`` and ``dy = dout * x``.
        """
        return dout * self.y, dout * self.x


# Example inputs: unit price of one apple, number of apples, and tax multiplier.
apple = 100
apple_num = 2
tax = 1.1

# One multiplication layer per product in the computation.
mul_apple_layer = MulLayer()
mul_tax_layer = MulLayer()

# forward: price = (apple * apple_num) * tax
apple_price = mul_apple_layer.forward(apple, apple_num)
price = mul_tax_layer.forward(apple_price, tax)

# backward: walk the layers in reverse order, seeding with d(price)/d(price) = 1
dprice = 1
dapple_price, dtax = mul_tax_layer.backward(dprice)
dapple, dapple_num = mul_apple_layer.backward(dapple_price)

# Final price, then the derivatives of the price with respect to each input.
print(price)
print(dapple, dapple_num, dtax)
